# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""For loading data into NMT models."""
from __future__ import print_function

import collections

import tensorflow as tf

from ..utils import vocab_utils


__all__ = ["BatchedInput", "get_iterator", "get_infer_iterator"]


# NOTE(ebrevdo): When we subclass this, instances' __dict__ becomes empty.
class BatchedInput(
    collections.namedtuple(
        "BatchedInput",
        "initializer source target_input target_output "
        "source_sequence_length target_sequence_length",
    )
):
    """Named bundle of the values produced by the iterator builders here.

    Fields, in positional order:
      initializer: initializer of the underlying batched iterator.
      source: batch of source sequences.
      target_input: batch of target sequences fed to the model, or None.
      target_output: batch of target sequences used as labels, or None.
      source_sequence_length: length of each source sequence in the batch.
      target_sequence_length: length of each target sequence, or None.
    """


def get_infer_iterator(
    src_dataset,
    src_vocab_table,
    batch_size,
    eos,
    src_max_len=None,
    use_char_encode=False,
):
    """Build an initializable iterator of padded source batches for inference.

    Each element of `src_dataset` is split into tokens with
    `tf.string_split`, optionally truncated, converted to integer ids and
    paired with its length before being batched with padding.

    Args:
      src_dataset: dataset whose elements are source strings.
      src_vocab_table: table exposing `lookup`, used to map tokens (and `eos`)
        to ids when `use_char_encode` is False.
      batch_size: number of elements per padded batch.
      eos: end-of-sentence token string; its id pads the source sequences
        unless `use_char_encode` is set.
      src_max_len: if truthy, keep only the first `src_max_len` tokens.
      use_char_encode: if True, encode tokens with
        `vocab_utils.tokens_to_bytes`, pad with `vocab_utils.EOS_CHAR_ID`, and
        report the length as the flattened size divided by
        `vocab_utils.DEFAULT_CHAR_MAXLEN`.

    Returns:
      A `BatchedInput` with `initializer`, `source` and
      `source_sequence_length` populated; the target fields are None.
    """
    # Pick the padding id plus the two per-example transforms for the chosen
    # encoding.  Only the vocab lookup below creates graph ops at this point.
    if use_char_encode:
        src_eos_id = vocab_utils.EOS_CHAR_ID

        def to_ids(tokens):
            # Convert the word strings to a flat vector of character ids.
            return tf.reshape(vocab_utils.tokens_to_bytes(tokens), [-1])

        def with_length(ids):
            # Word count = flattened size / characters per word.
            word_count = tf.to_int32(
                tf.size(ids) / vocab_utils.DEFAULT_CHAR_MAXLEN
            )
            return (ids, word_count)

    else:
        src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)

        def to_ids(tokens):
            # Convert the word strings to vocabulary ids.
            return tf.cast(src_vocab_table.lookup(tokens), tf.int32)

        def with_length(ids):
            return (ids, tf.size(ids))

    def tokenize(line):
        return tf.string_split([line]).values

    dataset = src_dataset.map(tokenize)
    if src_max_len:
        dataset = dataset.map(lambda tokens: tokens[:src_max_len])
    dataset = dataset.map(to_ids)
    dataset = dataset.map(with_length)

    # Each element is (ids, length): ids is a vector of unknown length and
    # length is a scalar.  Source rows are padded with the eos id (generally
    # not strictly needed, since computation past the true length is masked
    # out later); the padding value given for the scalar length is unused.
    batched_dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
        padding_values=(src_eos_id, 0),
    )

    iterator = batched_dataset.make_initializable_iterator()
    ids_batch, length_batch = iterator.get_next()
    return BatchedInput(
        initializer=iterator.initializer,
        source=ids_batch,
        source_sequence_length=length_batch,
        target_input=None,
        target_output=None,
        target_sequence_length=None,
    )


def get_iterator(
    src_dataset,
    tgt_dataset,
    src_vocab_table,
    tgt_vocab_table,
    batch_size,
    sos,
    eos,
    random_seed,
    num_buckets,
    src_max_len=None,
    tgt_max_len=None,
    num_parallel_calls=4,
    output_buffer_size=None,
    skip_count=None,
    num_shards=1,
    shard_index=0,
    reshuffle_each_iteration=True,
    use_char_encode=False,
):
    """Build an initializable iterator of padded (source, target) batches.

    The two datasets are zipped, sharded, optionally skipped, shuffled, split
    into tokens with `tf.string_split`, filtered to drop pairs where either
    side is empty, optionally truncated, converted to integer ids, and then
    batched with padding (grouped into length buckets when `num_buckets > 1`).

    Args:
      src_dataset: dataset whose elements are source strings.
      tgt_dataset: dataset whose elements are target strings, aligned with
        `src_dataset`.
      src_vocab_table: table exposing `lookup`, used to map source tokens (and
        `eos`) to ids when `use_char_encode` is False.
      tgt_vocab_table: table exposing `lookup`, used to map target tokens,
        `sos` and `eos` to ids.
      batch_size: number of pairs per padded batch (also the bucketing window
        size).
      sos: start-of-sentence token string, prepended to the target input.
      eos: end-of-sentence token string, appended to the target output and
        used for padding.
      random_seed: seed passed to the dataset shuffle.
      num_buckets: if greater than 1, pairs are grouped into `num_buckets`
        buckets by length before batching.
      src_max_len: if truthy, keep only the first `src_max_len` source tokens;
        also determines the bucket width.
      tgt_max_len: if truthy, keep only the first `tgt_max_len` target tokens.
      num_parallel_calls: parallelism passed to the dataset `map` calls.
      output_buffer_size: shuffle and prefetch buffer size; defaults to
        `batch_size * 1000` when falsy.
      skip_count: if not None, number of (already sharded) pairs to skip.
      num_shards: number of shards the zipped dataset is split into.
      shard_index: index of the shard to read.
      reshuffle_each_iteration: passed through to the dataset shuffle.
      use_char_encode: if True, encode source tokens with
        `vocab_utils.tokens_to_bytes`, pad the source with
        `vocab_utils.EOS_CHAR_ID`, and report the source length as the
        flattened size divided by `vocab_utils.DEFAULT_CHAR_MAXLEN`.

    Returns:
      A `BatchedInput` holding the iterator initializer, the padded source,
      target-input (`sos`-prefixed) and target-output (`eos`-suffixed) id
      batches, and the per-example source and target lengths.
    """
    if not output_buffer_size:
        output_buffer_size = batch_size * 1000

    if use_char_encode:
        src_eos_id = vocab_utils.EOS_CHAR_ID
    else:
        src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)

    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

    # Shard first, then skip, so `skip_count` counts pairs within this shard.
    src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
    if skip_count is not None:
        src_tgt_dataset = src_tgt_dataset.skip(skip_count)

    src_tgt_dataset = src_tgt_dataset.shuffle(
        output_buffer_size, random_seed, reshuffle_each_iteration
    )

    # Split each line into its tokens.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (
            tf.string_split([src]).values,
            tf.string_split([tgt]).values,
        ),
        num_parallel_calls=num_parallel_calls,
    ).prefetch(output_buffer_size)

    # Filter zero length input sequences.
    src_tgt_dataset = src_tgt_dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0)
    )

    if src_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src[:src_max_len], tgt),
            num_parallel_calls=num_parallel_calls,
        ).prefetch(output_buffer_size)
    if tgt_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src, tgt[:tgt_max_len]),
            num_parallel_calls=num_parallel_calls,
        ).prefetch(output_buffer_size)

    # Convert the word strings to ids.  Word strings that are not in the
    # vocab get the lookup table's default_value integer.
    if use_char_encode:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (
                tf.reshape(vocab_utils.tokens_to_bytes(src), [-1]),
                tf.cast(tgt_vocab_table.lookup(tgt), tf.int32),
            ),
            num_parallel_calls=num_parallel_calls,
        )
    else:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (
                tf.cast(src_vocab_table.lookup(src), tf.int32),
                tf.cast(tgt_vocab_table.lookup(tgt), tf.int32),
            ),
            num_parallel_calls=num_parallel_calls,
        )

    src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size)

    # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with
    # <eos>.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (
            src,
            tf.concat(([tgt_sos_id], tgt), 0),
            tf.concat((tgt, [tgt_eos_id]), 0),
        ),
        num_parallel_calls=num_parallel_calls,
    ).prefetch(output_buffer_size)

    # Add in sequence lengths.
    if use_char_encode:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt_in, tgt_out: (
                src,
                tgt_in,
                tgt_out,
                # Source word count = flattened size / characters per word.
                tf.to_int32(tf.size(src) / vocab_utils.DEFAULT_CHAR_MAXLEN),
                tf.size(tgt_in),
            ),
            num_parallel_calls=num_parallel_calls,
        )
    else:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt_in, tgt_out: (
                src,
                tgt_in,
                tgt_out,
                tf.size(src),
                tf.size(tgt_in),
            ),
            num_parallel_calls=num_parallel_calls,
        )

    src_tgt_dataset = src_tgt_dataset.prefetch(output_buffer_size)

    def batching_func(x):
        return x.padded_batch(
            batch_size,
            # The first three entries are the source and target line rows;
            # these have unknown-length vectors.  The last two entries are
            # the source and target row sizes; these are scalars.
            padded_shapes=(
                tf.TensorShape([None]),  # src
                tf.TensorShape([None]),  # tgt_input
                tf.TensorShape([None]),  # tgt_output
                tf.TensorShape([]),  # src_len
                tf.TensorShape([]),  # tgt_len
            ),
            # Pad the source and target sequences with eos tokens.
            # (Though notice we don't generally need to do this since
            # later on we will be masking out calculations past the true
            # sequence.)
            padding_values=(
                src_eos_id,  # src
                tgt_eos_id,  # tgt_input
                tgt_eos_id,  # tgt_output
                0,  # src_len -- unused
                0,  # tgt_len -- unused
            ),
        )

    if num_buckets > 1:
        # Calculate bucket_width by maximum source sequence length.
        # Pairs with length [0, bucket_width) go to bucket 0, length
        # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with
        # length over ((num_buckets - 1) * bucket_width) words all go into
        # the last bucket.  Without src_max_len the width defaults to 10
        # (buckets for lengths 0-9, 10-19, ...).
        if src_max_len:
            bucket_width = (src_max_len + num_buckets - 1) // num_buckets
        else:
            bucket_width = 10

        def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
            # Bucket sentence pairs by the longer of their source sentence
            # and target sentence.
            bucket_id = tf.maximum(
                src_len // bucket_width, tgt_len // bucket_width
            )
            # Valid bucket ids are 0 .. num_buckets - 1.  Clamping to
            # num_buckets (as was done previously) silently created an extra
            # (num_buckets + 1)-th bucket for the longest pairs.
            return tf.to_int64(tf.minimum(num_buckets - 1, bucket_id))

        def reduce_func(unused_key, windowed_data):
            return batching_func(windowed_data)

        batched_dataset = src_tgt_dataset.apply(
            tf.contrib.data.group_by_window(
                key_func=key_func, reduce_func=reduce_func, window_size=batch_size
            )
        )

    else:
        batched_dataset = batching_func(src_tgt_dataset)

    batched_iter = batched_dataset.make_initializable_iterator()
    (src_ids, tgt_input_ids, tgt_output_ids, src_seq_len, tgt_seq_len) = (
        batched_iter.get_next()
    )
    return BatchedInput(
        initializer=batched_iter.initializer,
        source=src_ids,
        target_input=tgt_input_ids,
        target_output=tgt_output_ids,
        source_sequence_length=src_seq_len,
        target_sequence_length=tgt_seq_len,
    )
